#  Copyright (c) ZenML GmbH 2023. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at:
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
#  or implied. See the License for the specific language governing
#  permissions and limitations under the License.
"""SQLModel implementation of model tables."""

from typing import TYPE_CHECKING, Any, List, Optional, Sequence, cast
from uuid import UUID, uuid4

from pydantic import ConfigDict
from sqlalchemy import (
    BOOLEAN,
    INTEGER,
    TEXT,
    Column,
    UniqueConstraint,
)
from sqlalchemy.orm import joinedload, object_session, selectinload
from sqlalchemy.sql.base import ExecutableOption
from sqlmodel import Field, Relationship, desc, select

from zenml.enums import (
    MetadataResourceTypes,
    TaggableResourceTypes,
    VisualizationResourceTypes,
)
from zenml.models import (
    BaseResponseMetadata,
    ModelRequest,
    ModelResponse,
    ModelResponseBody,
    ModelResponseMetadata,
    ModelResponseResources,
    ModelUpdate,
    ModelVersionArtifactRequest,
    ModelVersionArtifactResponse,
    ModelVersionArtifactResponseBody,
    ModelVersionPipelineRunRequest,
    ModelVersionPipelineRunResponse,
    ModelVersionPipelineRunResponseBody,
    ModelVersionRequest,
    ModelVersionResponse,
    ModelVersionResponseBody,
    ModelVersionResponseMetadata,
    ModelVersionResponseResources,
    Page,
)
from zenml.utils.time_utils import utc_now
from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema
from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema
from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME
from zenml.zen_stores.schemas.curated_visualization_schemas import (
    CuratedVisualizationSchema,
)
from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
from zenml.zen_stores.schemas.project_schemas import ProjectSchema
from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
from zenml.zen_stores.schemas.tag_schemas import TagSchema
from zenml.zen_stores.schemas.user_schemas import UserSchema
from zenml.zen_stores.schemas.utils import (
    RunMetadataInterface,
    get_page_from_list,
    jl_arg,
)

if TYPE_CHECKING:
    from zenml.zen_stores.schemas import ServiceSchema, StepRunSchema


class ModelSchema(NamedSchema, table=True):
    """SQL table definition for models."""

    __tablename__ = "model"
    # A model name may only appear once per project.
    __table_args__ = (
        UniqueConstraint(
            "name",
            "project_id",
            name="unique_model_name_in_project",
        ),
    )

    # Project the model belongs to.
    project_id: UUID = build_foreign_key_field(
        source=__tablename__,
        target=ProjectSchema.__tablename__,
        source_column="project_id",
        target_column="id",
        ondelete="CASCADE",
        nullable=False,
    )
    project: "ProjectSchema" = Relationship(back_populates="models")

    # Optional owner of the model.
    user_id: Optional[UUID] = build_foreign_key_field(
        source=__tablename__,
        target=UserSchema.__tablename__,
        source_column="user_id",
        target_column="id",
        ondelete="SET NULL",
        nullable=True,
    )
    user: Optional["UserSchema"] = Relationship(back_populates="models")

    # Free-text descriptive columns; all of them are nullable TEXT columns.
    license: str = Field(sa_column=Column(TEXT, nullable=True))
    description: str = Field(sa_column=Column(TEXT, nullable=True))
    audience: str = Field(sa_column=Column(TEXT, nullable=True))
    use_cases: str = Field(sa_column=Column(TEXT, nullable=True))
    limitations: str = Field(sa_column=Column(TEXT, nullable=True))
    trade_offs: str = Field(sa_column=Column(TEXT, nullable=True))
    ethics: str = Field(sa_column=Column(TEXT, nullable=True))
    save_models_to_registry: bool = Field(
        sa_column=Column(BOOLEAN, nullable=False)
    )

    # Tags are joined through the `tag_resource` table, restricted to rows
    # whose resource type is MODEL, and ordered by tag name.
    tags: List["TagSchema"] = Relationship(
        sa_relationship_kwargs={
            "primaryjoin": f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
            "secondary": "tag_resource",
            "secondaryjoin": "TagSchema.id == foreign(TagResourceSchema.tag_id)",
            "order_by": "TagSchema.name",
            "overlaps": "tags",
        },
    )
    model_versions: List["ModelVersionSchema"] = Relationship(
        back_populates="model",
        sa_relationship_kwargs=dict(cascade="delete"),
    )
    # Curated visualizations whose resource type is MODEL and whose resource
    # ID is this model's ID, ordered by their display order.
    visualizations: List["CuratedVisualizationSchema"] = Relationship(
        sa_relationship_kwargs={
            "primaryjoin": (
                "and_(CuratedVisualizationSchema.resource_type"
                f"=='{VisualizationResourceTypes.MODEL.value}', "
                "foreign(CuratedVisualizationSchema.resource_id)==ModelSchema.id)"
            ),
            "overlaps": "visualizations",
            "cascade": "delete",
            "order_by": "CuratedVisualizationSchema.display_order",
        },
    )

    @classmethod
    def get_query_options(
        cls,
        include_metadata: bool = False,
        include_resources: bool = False,
        **kwargs: Any,
    ) -> Sequence[ExecutableOption]:
        """Build the loader options to use when querying this schema.

        Args:
            include_metadata: Whether metadata will be included when converting
                the schema to a model. Currently unused by this
                implementation.
            include_resources: Whether resources will be included when
                converting the schema to a model.
            **kwargs: Keyword arguments to allow schema specific logic

        Returns:
            A list of query options.
        """
        if not include_resources:
            return []

        # NOTE: eager loading of `ModelSchema.tags` is currently disabled.
        return [
            joinedload(jl_arg(ModelSchema.user)),
            selectinload(jl_arg(ModelSchema.visualizations)),
        ]

    @property
    def latest_version(self) -> Optional["ModelVersionSchema"]:
        """Look up the version of this model with the highest number.

        Raises:
            RuntimeError: If no session for the schema exists.

        Returns:
            The latest version for this model, or `None` if the model has no
            versions.
        """
        session = object_session(self)
        if not session:
            raise RuntimeError(
                "Missing DB session to fetch latest version for model."
            )

        # Highest version number first, and only the first row is needed.
        newest_version_query = (
            select(ModelVersionSchema)
            .where(ModelVersionSchema.model_id == self.id)
            .order_by(desc(ModelVersionSchema.number))
            .limit(1)
        )
        return session.execute(newest_version_query).scalars().one_or_none()

    @classmethod
    def from_request(cls, model_request: ModelRequest) -> "ModelSchema":
        """Create a `ModelSchema` from a `ModelRequest`.

        Args:
            model_request: The request model to convert.

        Returns:
            The converted schema.
        """
        # These attributes share the same name on the request and the schema.
        same_name_fields = (
            "license",
            "description",
            "audience",
            "use_cases",
            "limitations",
            "trade_offs",
            "ethics",
            "save_models_to_registry",
        )
        copied_values = {
            field_name: getattr(model_request, field_name)
            for field_name in same_name_fields
        }
        return cls(
            name=model_request.name,
            project_id=model_request.project,
            user_id=model_request.user,
            **copied_values,
        )

    def to_model(
        self,
        include_metadata: bool = False,
        include_resources: bool = False,
        **kwargs: Any,
    ) -> ModelResponse:
        """Create a `ModelResponse` from this `ModelSchema`.

        Args:
            include_metadata: Whether the metadata will be filled.
            include_resources: Whether the resources will be filled.
            **kwargs: Keyword arguments to allow schema specific logic


        Returns:
            The created `ModelResponse`.
        """
        response_metadata = (
            ModelResponseMetadata(
                license=self.license,
                description=self.description,
                audience=self.audience,
                use_cases=self.use_cases,
                limitations=self.limitations,
                trade_offs=self.trade_offs,
                ethics=self.ethics,
                save_models_to_registry=self.save_models_to_registry,
            )
            if include_metadata
            else None
        )

        response_resources = None
        if include_resources:
            # `latest_version` runs a query, so read the property only once.
            newest = self.latest_version
            response_resources = ModelResponseResources(
                user=self.user.to_model() if self.user else None,
                tags=[tag.to_model() for tag in self.tags],
                latest_version_name=newest.name if newest else None,
                latest_version_id=newest.id if newest else None,
                visualizations=[
                    visualization.to_model(
                        include_metadata=False,
                        include_resources=False,
                    )
                    for visualization in self.visualizations
                ],
            )

        response_body = ModelResponseBody(
            user_id=self.user_id,
            project_id=self.project_id,
            created=self.created,
            updated=self.updated,
        )

        return ModelResponse(
            id=self.id,
            name=self.name,
            body=response_body,
            metadata=response_metadata,
            resources=response_resources,
        )

    def update(
        self,
        model_update: ModelUpdate,
    ) -> "ModelSchema":
        """Apply a `ModelUpdate` to this `ModelSchema`.

        Only fields that were explicitly set to a non-`None` value on the
        update are applied, and the `updated` timestamp is refreshed.

        Args:
            model_update: The `ModelUpdate` to update from.

        Returns:
            The updated `ModelSchema`.
        """
        changed_values = model_update.model_dump(
            exclude_unset=True, exclude_none=True
        )
        for attribute_name, new_value in changed_values.items():
            # Tags are handled separately
            if attribute_name not in ("add_tags", "remove_tags"):
                setattr(self, attribute_name, new_value)

        self.updated = utc_now()
        return self


class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
    """SQL table definition for model versions."""

    __tablename__ = MODEL_VERSION_TABLENAME
    __table_args__ = (
        # Three unique constraints are required:
        # - one so that version numbers are unique within a model
        # - one so that explicit, user-provided names are unique within a
        #   model
        # - one so that a pipeline run only produces a single
        #   auto-incremented version per model
        UniqueConstraint(
            "number",
            "model_id",
            name="unique_version_number_for_model_id",
        ),
        UniqueConstraint(
            "name",
            "model_id",
            name="unique_version_for_model_id",
        ),
        UniqueConstraint(
            "model_id",
            "producer_run_id_if_numeric",
            name="unique_numeric_version_for_pipeline_run",
        ),
    )

    # Project the model version belongs to.
    project_id: UUID = build_foreign_key_field(
        source=__tablename__,
        target=ProjectSchema.__tablename__,
        source_column="project_id",
        target_column="id",
        ondelete="CASCADE",
        nullable=False,
    )
    project: "ProjectSchema" = Relationship(back_populates="model_versions")

    # Optional owner of the model version.
    user_id: Optional[UUID] = build_foreign_key_field(
        source=__tablename__,
        target=UserSchema.__tablename__,
        source_column="user_id",
        target_column="id",
        ondelete="SET NULL",
        nullable=True,
    )
    user: Optional["UserSchema"] = Relationship(
        back_populates="model_versions"
    )

    # Model this version belongs to.
    model_id: UUID = build_foreign_key_field(
        source=__tablename__,
        target=ModelSchema.__tablename__,
        source_column="model_id",
        target_column="id",
        ondelete="CASCADE",
        nullable=False,
    )
    model: "ModelSchema" = Relationship(back_populates="model_versions")
    # Tags are joined through the `tag_resource` table, restricted to rows
    # whose resource type is MODEL_VERSION, and ordered by tag name.
    tags: List["TagSchema"] = Relationship(
        sa_relationship_kwargs={
            "primaryjoin": f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
            "secondary": "tag_resource",
            "secondaryjoin": "TagSchema.id == foreign(TagResourceSchema.tag_id)",
            "order_by": "TagSchema.name",
            "overlaps": "tags",
        },
    )

    services: List["ServiceSchema"] = Relationship(
        back_populates="model_version",
    )

    number: int = Field(sa_column=Column(INTEGER, nullable=False))
    description: str = Field(sa_column=Column(TEXT, nullable=True))
    stage: str = Field(sa_column=Column(TEXT, nullable=True))

    # Run metadata is joined through the `run_metadata_resource` table,
    # restricted to rows whose resource type is MODEL_VERSION.
    run_metadata: List["RunMetadataSchema"] = Relationship(
        sa_relationship_kwargs={
            "secondary": "run_metadata_resource",
            "primaryjoin": f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)",
            "secondaryjoin": "RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)",
            "overlaps": "run_metadata",
        },
    )
    pipeline_runs: List["PipelineRunSchema"] = Relationship(
        back_populates="model_version",
    )
    step_runs: List["StepRunSchema"] = Relationship(
        back_populates="model_version"
    )

    # Each pipeline run should only create a single numeric version per
    # model, which is enforced with a unique constraint. If a value of a
    # unique constraint is NULL it is ignored and the remaining values in the
    # unique constraint have to be unique. Here, however, the constraint
    # should only apply when there is a producer run and the version is
    # numeric. In every other case this column falls back to the model
    # version ID (which is the primary key and therefore unique).
    producer_run_id_if_numeric: UUID

    # Needed for cascade deletion behavior
    artifact_links: List["ModelVersionArtifactSchema"] = Relationship(
        back_populates="model_version",
        sa_relationship_kwargs=dict(cascade="delete"),
    )
    pipeline_run_links: List["ModelVersionPipelineRunSchema"] = Relationship(
        back_populates="model_version",
        sa_relationship_kwargs=dict(cascade="delete"),
    )

    # TODO: In Pydantic v2, the `model_` is a protected namespaces for all
    #  fields defined under base models. If not handled, this raises a warning.
    #  It is possible to suppress this warning message with the following
    #  configuration, however the ultimate solution is to rename these fields.
    #  Even though they do not cause any problems right now, if we are not
    #  careful we might overwrite some fields protected by pydantic.
    model_config = ConfigDict(protected_namespaces=())  # type: ignore[assignment]

    @classmethod
    def get_query_options(
        cls,
        include_metadata: bool = False,
        include_resources: bool = False,
        **kwargs: Any,
    ) -> Sequence[ExecutableOption]:
        """Build the loader options to use when querying this schema.

        Args:
            include_metadata: Whether metadata will be included when converting
                the schema to a model. Currently unused by this
                implementation.
            include_resources: Whether resources will be included when
                converting the schema to a model.
            **kwargs: Keyword arguments to allow schema specific logic

        Returns:
            A list of query options.
        """
        # The parent model is always loaded with an inner join.
        loader_options = [
            joinedload(jl_arg(ModelVersionSchema.model), innerjoin=True),
        ]

        # NOTE: eager loading of `run_metadata`, `services` and `tags` is
        # currently disabled.
        if include_resources:
            loader_options.append(joinedload(jl_arg(ModelVersionSchema.user)))

        return loader_options

    @classmethod
    def from_request(
        cls,
        model_version_request: ModelVersionRequest,
        model_version_number: int,
        producer_run_id: Optional[UUID] = None,
    ) -> "ModelVersionSchema":
        """Create a `ModelVersionSchema` from a `ModelVersionRequest`.

        Args:
            model_version_request: The request model version to convert.
            model_version_number: The model version number.
            producer_run_id: The ID of the producer run.

        Returns:
            The converted schema.
        """
        schema_id = uuid4()
        # A version is numeric when its name is exactly its version number.
        is_numeric = str(model_version_number) == model_version_request.name

        # Only numeric versions with a producer run take part in the
        # per-run unique constraint; all others fall back to their own ID.
        if producer_run_id and is_numeric:
            unique_run_key = producer_run_id
        else:
            unique_run_key = schema_id

        return cls(
            id=schema_id,
            project_id=model_version_request.project,
            user_id=model_version_request.user,
            model_id=model_version_request.model,
            name=model_version_request.name,
            number=model_version_number,
            description=model_version_request.description,
            stage=model_version_request.stage,
            producer_run_id_if_numeric=unique_run_key,
        )

    def to_model(
        self,
        include_metadata: bool = False,
        include_resources: bool = False,
        **kwargs: Any,
    ) -> ModelVersionResponse:
        """Create a `ModelVersionResponse` from this `ModelVersionSchema`.

        Args:
            include_metadata: Whether the metadata will be filled.
            include_resources: Whether the resources will be filled.
            **kwargs: Keyword arguments to allow schema specific logic


        Returns:
            The created `ModelVersionResponse`.
        """
        from zenml.models import ServiceResponse

        response_metadata = None
        response_resources = None

        if include_metadata:
            response_metadata = ModelVersionResponseMetadata(
                description=self.description,
                run_metadata=self.fetch_metadata(),
            )

        if include_resources:
            service_page = get_page_from_list(
                items_list=self.services,
                response_model=ServiceResponse,
                include_resources=include_resources,
                include_metadata=include_metadata,
            )
            typed_service_page = cast(Page[ServiceResponse], service_page)
            response_resources = ModelVersionResponseResources(
                user=self.user.to_model() if self.user else None,
                services=typed_service_page,
                tags=[tag.to_model() for tag in self.tags],
            )

        response_body = ModelVersionResponseBody(
            user_id=self.user_id,
            project_id=self.project_id,
            created=self.created,
            updated=self.updated,
            stage=self.stage,
            number=self.number,
            model=self.model.to_model(),
        )

        return ModelVersionResponse(
            id=self.id,
            name=self.name,
            body=response_body,
            metadata=response_metadata,
            resources=response_resources,
        )

    def update(
        self,
        target_stage: Optional[str] = None,
        target_name: Optional[str] = None,
        target_description: Optional[str] = None,
    ) -> "ModelVersionSchema":
        """Update stage, name and/or description of this `ModelVersionSchema`.

        Arguments left as `None` keep their current value. The `updated`
        timestamp is refreshed on every call.

        Args:
            target_stage: The stage to be updated.
            target_name: The version name to be updated.
            target_description: The version description to be updated.

        Returns:
            The updated `ModelVersionSchema`.
        """
        requested_changes = (
            ("stage", target_stage),
            ("name", target_name),
            ("description", target_description),
        )
        for attribute_name, new_value in requested_changes:
            if new_value is not None:
                setattr(self, attribute_name, new_value)

        self.updated = utc_now()
        return self


class ModelVersionArtifactSchema(BaseSchema, table=True):
    """SQL Model for linking of Model Versions and Artifacts M:M."""

    __tablename__ = "model_versions_artifacts"

    # Model version side of the link.
    model_version_id: UUID = build_foreign_key_field(
        source=__tablename__,
        target=ModelVersionSchema.__tablename__,
        source_column="model_version_id",
        target_column="id",
        ondelete="CASCADE",
        nullable=False,
    )
    # `ModelVersionSchema.artifact_links` declares
    # `back_populates="model_version"`; reciprocate here so that both sides of
    # the relationship are kept in sync.
    model_version: "ModelVersionSchema" = Relationship(
        back_populates="artifact_links"
    )
    # Artifact version side of the link.
    artifact_version_id: UUID = build_foreign_key_field(
        source=__tablename__,
        target=ArtifactVersionSchema.__tablename__,
        source_column="artifact_version_id",
        target_column="id",
        ondelete="CASCADE",
        nullable=False,
    )
    artifact_version: "ArtifactVersionSchema" = Relationship(
        back_populates="model_versions_artifacts_links"
    )

    # TODO: In Pydantic v2, the `model_` is a protected namespaces for all
    #  fields defined under base models. If not handled, this raises a warning.
    #  It is possible to suppress this warning message with the following
    #  configuration, however the ultimate solution is to rename these fields.
    #  Even though they do not cause any problems right now, if we are not
    #  careful we might overwrite some fields protected by pydantic.
    model_config = ConfigDict(protected_namespaces=())  # type: ignore[assignment]

    @classmethod
    def from_request(
        cls,
        model_version_artifact_request: ModelVersionArtifactRequest,
    ) -> "ModelVersionArtifactSchema":
        """Convert an `ModelVersionArtifactRequest` to a `ModelVersionArtifactSchema`.

        Args:
            model_version_artifact_request: The request link to convert.

        Returns:
            The converted schema.
        """
        return cls(
            model_version_id=model_version_artifact_request.model_version,
            artifact_version_id=model_version_artifact_request.artifact_version,
        )

    def to_model(
        self,
        include_metadata: bool = False,
        include_resources: bool = False,
        **kwargs: Any,
    ) -> ModelVersionArtifactResponse:
        """Convert an `ModelVersionArtifactSchema` to an `ModelVersionArtifactResponse`.

        Args:
            include_metadata: Whether the metadata will be filled. When set,
                an empty `BaseResponseMetadata` is attached.
            include_resources: Whether the resources will be filled.
                Currently unused by this implementation.
            **kwargs: Keyword arguments to allow schema specific logic


        Returns:
            The created `ModelVersionArtifactResponse`.
        """
        return ModelVersionArtifactResponse(
            id=self.id,
            body=ModelVersionArtifactResponseBody(
                created=self.created,
                updated=self.updated,
                # The model version is referenced by ID only, while the
                # artifact version is embedded as a full response model.
                model_version=self.model_version_id,
                artifact_version=self.artifact_version.to_model(),
            ),
            metadata=BaseResponseMetadata() if include_metadata else None,
        )


class ModelVersionPipelineRunSchema(BaseSchema, table=True):
    """SQL Model for linking of Model Versions and Pipeline Runs M:M."""

    __tablename__ = "model_versions_runs"

    # Model version side of the link.
    model_version_id: UUID = build_foreign_key_field(
        source=__tablename__,
        target=ModelVersionSchema.__tablename__,
        source_column="model_version_id",
        target_column="id",
        ondelete="CASCADE",
        nullable=False,
    )
    # `ModelVersionSchema.pipeline_run_links` declares
    # `back_populates="model_version"`; reciprocate here so that both sides of
    # the relationship are kept in sync.
    model_version: "ModelVersionSchema" = Relationship(
        back_populates="pipeline_run_links"
    )
    # Pipeline run side of the link.
    pipeline_run_id: UUID = build_foreign_key_field(
        source=__tablename__,
        target=PipelineRunSchema.__tablename__,
        source_column="pipeline_run_id",
        target_column="id",
        ondelete="CASCADE",
        nullable=False,
    )
    pipeline_run: "PipelineRunSchema" = Relationship()

    # TODO: In Pydantic v2, the `model_` is a protected namespaces for all
    #  fields defined under base models. If not handled, this raises a warning.
    #  It is possible to suppress this warning message with the following
    #  configuration, however the ultimate solution is to rename these fields.
    #  Even though they do not cause any problems right now, if we are not
    #  careful we might overwrite some fields protected by pydantic.
    model_config = ConfigDict(protected_namespaces=())  # type: ignore[assignment]

    @classmethod
    def from_request(
        cls,
        model_version_pipeline_run_request: ModelVersionPipelineRunRequest,
    ) -> "ModelVersionPipelineRunSchema":
        """Convert an `ModelVersionPipelineRunRequest` to an `ModelVersionPipelineRunSchema`.

        Args:
            model_version_pipeline_run_request: The request link to convert.

        Returns:
            The converted schema.
        """
        return cls(
            model_version_id=model_version_pipeline_run_request.model_version,
            pipeline_run_id=model_version_pipeline_run_request.pipeline_run,
        )

    def to_model(
        self,
        include_metadata: bool = False,
        include_resources: bool = False,
        **kwargs: Any,
    ) -> ModelVersionPipelineRunResponse:
        """Convert an `ModelVersionPipelineRunSchema` to an `ModelVersionPipelineRunResponse`.

        Args:
            include_metadata: Whether the metadata will be filled. When set,
                an empty `BaseResponseMetadata` is attached.
            include_resources: Whether the resources will be filled.
                Currently unused by this implementation.
            **kwargs: Keyword arguments to allow schema specific logic


        Returns:
            The created `ModelVersionPipelineRunResponse`.
        """
        return ModelVersionPipelineRunResponse(
            id=self.id,
            body=ModelVersionPipelineRunResponseBody(
                created=self.created,
                updated=self.updated,
                # The model version is referenced by ID only, while the
                # pipeline run is embedded as a full response model.
                model_version=self.model_version_id,
                pipeline_run=self.pipeline_run.to_model(),
            ),
            metadata=BaseResponseMetadata() if include_metadata else None,
        )
